Extending mixed-effects models by allowing the relationship (slope) to vary across different groups.
General Principles
To model the relationship between predictor variables and a dependent variable while allowing for varying effects across groups or clusters, we use a varying slopes model.
This approach is useful when we expect the relationship between predictors and the dependent variable to differ across groups (e.g., different slopes for different subjects, locations, or time periods). This allows every unit in the data to have its own unique response to any treatment, exposure, or event, while also improving estimates via pooling.
The idea is pretty similar to categorical models, where a slope is specified for each category. However, here, we also estimate relationships between different groups. This leads to a different mathematical approach, as to model these relationships between groups, we model a matrix of covariance 🛈.
To construct the covariance matrix, we use an SRS decomposition where S is a diagonal matrix of standard deviations and R is a correlation matrix. To model the correlation matrix, we use an LKJcorr distribution parametrized with a single control parameter η that controls the amount of regularization. η is usually set to 2 to define a weakly informative prior that is skeptical of extreme correlations near −1 or 1. When we use LKJcorr(1), the prior is flat over all valid correlation matrices. When the value is greater than 1, then extreme correlations are less likely.
The standard deviations in S are model with a prior that constrains them to strictly positive values.
Example
Below is an example code snippet demonstrating Bayesian regression with varying effects. This example is based on McElreath (2018).
/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
bf v 0.0.47 package loaded
jax.local_device_count 32
/home/sosa/work/BF/BayesForge/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:
invalid value encountered in scalar divide
mean
sd
hdi_5.5%
hdi_94.5%
mcse_mean
mcse_sd
ess_bulk
ess_tail
r_hat
Rho[0, 0]
1.00
0.00
1.00
1.00
0.00
0.00
14408.24
14408.24
inf
Rho[0, 1]
-0.56
0.22
-0.89
-0.24
0.01
0.00
1752.10
1183.04
1.00
Rho[1, 0]
-0.56
0.22
-0.89
-0.24
0.01
0.00
1752.10
1183.04
1.00
Rho[1, 1]
1.00
0.00
1.00
1.00
0.00
0.00
4021.66
3814.38
1.05
a
3.86
0.26
3.46
4.29
0.00
0.00
4211.85
2808.09
1.00
a_b_cafe[0, 0]
4.23
0.21
3.88
4.56
0.00
0.00
4246.90
2753.36
1.00
a_b_cafe[0, 1]
-1.49
0.23
-1.88
-1.13
0.00
0.00
3533.17
2772.77
1.00
a_b_cafe[1, 0]
0.40
0.23
0.02
0.76
0.00
0.00
2936.09
3106.71
1.00
a_b_cafe[1, 1]
-0.72
0.29
-1.20
-0.26
0.01
0.00
1943.59
2541.81
1.00
a_b_cafe[2, 0]
4.43
0.21
4.06
4.75
0.00
0.00
4660.96
3267.56
1.00
a_b_cafe[2, 1]
-1.40
0.23
-1.75
-1.03
0.00
0.00
4060.15
2632.03
1.00
a_b_cafe[3, 0]
3.64
0.23
3.27
3.99
0.00
0.00
3076.67
2735.60
1.00
a_b_cafe[3, 1]
-1.60
0.27
-2.03
-1.20
0.01
0.00
1829.69
2648.20
1.00
a_b_cafe[4, 0]
3.21
0.22
2.87
3.57
0.00
0.00
3473.14
2807.37
1.00
a_b_cafe[4, 1]
-1.03
0.25
-1.43
-0.64
0.01
0.00
1895.39
2453.51
1.00
a_b_cafe[5, 0]
3.01
0.22
2.66
3.34
0.00
0.00
3541.30
3190.21
1.00
a_b_cafe[5, 1]
-1.39
0.25
-1.78
-0.98
0.00
0.00
2785.63
3036.67
1.00
a_b_cafe[6, 0]
2.87
0.22
2.51
3.22
0.00
0.00
1998.02
2188.31
1.00
a_b_cafe[6, 1]
-0.86
0.27
-1.33
-0.47
0.01
0.01
1211.87
1529.38
1.00
a_b_cafe[7, 0]
3.02
0.22
2.65
3.34
0.00
0.00
4163.17
2692.32
1.00
a_b_cafe[7, 1]
-1.43
0.25
-1.83
-1.05
0.00
0.00
2515.43
2446.61
1.00
a_b_cafe[8, 0]
5.49
0.22
5.14
5.83
0.00
0.00
3705.00
2692.29
1.00
a_b_cafe[8, 1]
-1.49
0.24
-1.88
-1.11
0.00
0.00
3252.27
2745.19
1.00
a_b_cafe[9, 0]
3.46
0.21
3.13
3.80
0.00
0.00
3890.75
2430.43
1.00
a_b_cafe[9, 1]
-1.37
0.23
-1.73
-0.99
0.00
0.00
3595.03
2944.52
1.00
a_b_cafe[10, 0]
3.96
0.21
3.65
4.33
0.00
0.00
3057.81
2108.29
1.00
a_b_cafe[10, 1]
-1.12
0.25
-1.49
-0.72
0.01
0.00
1840.29
2968.54
1.00
a_b_cafe[11, 0]
4.21
0.21
3.91
4.56
0.00
0.00
3333.13
2455.28
1.00
a_b_cafe[11, 1]
-1.31
0.23
-1.68
-0.95
0.00
0.00
3665.55
2985.66
1.00
a_b_cafe[12, 0]
5.39
0.22
5.03
5.73
0.00
0.00
3218.64
2721.53
1.00
a_b_cafe[12, 1]
-1.66
0.25
-2.08
-1.29
0.00
0.00
2545.57
2842.72
1.00
a_b_cafe[13, 0]
3.52
0.21
3.21
3.88
0.00
0.00
4455.94
3129.22
1.00
a_b_cafe[13, 1]
-1.21
0.23
-1.56
-0.84
0.00
0.00
3530.66
2965.95
1.00
a_b_cafe[14, 0]
3.48
0.21
3.15
3.82
0.00
0.00
3597.24
2854.85
1.00
a_b_cafe[14, 1]
-1.14
0.24
-1.50
-0.74
0.00
0.00
2675.32
2497.12
1.00
a_b_cafe[15, 0]
4.30
0.21
3.98
4.64
0.00
0.00
4774.13
3162.85
1.00
a_b_cafe[15, 1]
-1.40
0.23
-1.77
-1.04
0.00
0.00
3984.24
3133.99
1.00
a_b_cafe[16, 0]
4.55
0.21
4.21
4.88
0.00
0.00
2460.98
2889.72
1.00
a_b_cafe[16, 1]
-1.60
0.24
-1.99
-1.21
0.01
0.00
1907.81
2387.54
1.00
a_b_cafe[17, 0]
4.66
0.21
4.35
5.00
0.00
0.00
4086.72
3105.06
1.00
a_b_cafe[17, 1]
-1.35
0.23
-1.69
-0.96
0.00
0.00
3467.95
3220.21
1.00
a_b_cafe[18, 0]
5.42
0.22
5.05
5.75
0.00
0.00
2735.52
2947.42
1.00
a_b_cafe[18, 1]
-1.76
0.26
-2.18
-1.36
0.01
0.00
1839.87
3035.21
1.00
a_b_cafe[19, 0]
3.89
0.21
3.54
4.20
0.00
0.00
4067.57
3290.06
1.00
a_b_cafe[19, 1]
-1.20
0.24
-1.55
-0.81
0.00
0.00
2978.96
2725.84
1.00
b
-1.32
0.11
-1.50
-1.15
0.00
0.00
2452.25
2609.21
1.00
sigma
0.54
0.03
0.49
0.59
0.00
0.00
3524.04
2303.50
1.00
sigma_cafe[0]
1.17
0.19
0.89
1.45
0.00
0.00
4369.97
2984.81
1.00
sigma_cafe[1]
0.35
0.12
0.16
0.53
0.00
0.00
672.14
791.70
1.00
Code
from BayesForge import bfimport jax.numpy as jnp# Setup device------------------------------------------------m = bf(platform="cpu")# Import Data & Data Manipulation ------------------------------------------------data_path = m.load.sim_multivariate_normal(only_path=True)m.data(data_path, sep=",")m.data_on_model =dict( cafe=jnp.array(m.df.cafe.values, dtype=jnp.int32), wait=jnp.array(m.df.wait.values, dtype=jnp.float32), N_cafes=len(m.df.cafe.unique()), afternoon=jnp.array(m.df.afternoon.values, dtype=jnp.float32),)# Define model ------------------------------------------------def model(cafe, wait, N_cafes, afternoon): sigma = m.dist.exponential(1, name="sigma") varying_intercept, varying_slope = m.effects.varying_effects( N_vars=1, N_group=N_cafes, group_id=cafe, group_name="cafe", centered=False ) mu = varying_intercept + varying_slope * afternoon m.dist.normal(mu, sigma, obs=wait)# Run sampler ------------------------------------------------m.fit(model, progress_bar=False)
jax.local_device_count 32
library(BayesForge)jnp = reticulate::import('jax.numpy')# Setup platform------------------------------------------------m=importBF(platform='cpu')# Import data ------------------------------------------------m$data(paste(system.file(package ="BayesForge"),"/data/Sim data multivariatenormal.csv", sep =''), sep=',')m$data_to_model(list('cafe', 'wait', 'afternoon'))# Import data ------------------------------------------------m$data(paste(system.file(package ="BayesForge"),"/data/Sim data multivariatenormal.csv", sep =''), sep=',')m$data_to_model(list('cafe', 'wait', 'afternoon'))m$data_on_model# Define model ------------------------------------------------model <-function(cafe, afternoon, wait, N_cafes =as.integer(20) ){ a =bf.dist.normal(5, 2, name ='a') b =bf.dist.normal(-1, 0.5, name ='b') sigma_cafe =bf.dist.exponential(1, shape=c(2), name ='sigma_cafe') sigma =bf.dist.exponential( 1, name ='sigma') Rho =bf.dist.lkj(as.integer(2), as.integer(2), name ='Rho') cov = jnp$outer(sigma_cafe, sigma_cafe) * Rho a_cafe_b_cafe =bf.dist.multivariate_normal( jnp$squeeze(jnp$stack(list(a, b))), cov, shape =c(N_cafes), name ='a_cafe') a_cafe = a_cafe_b_cafe[, 0] b_cafe = a_cafe_b_cafe[, 1] mu = a_cafe[cafe] + b_cafe[cafe] * afternoonbf.dist.normal(mu, sigma, obs=wait)}# Run MCMC ------------------------------------------------m$fit(model) # Optimize model parameters through MCMC sampling# Summary ------------------------------------------------m$summary() # Get posterior distribution
usingBayesForge# Setup device------------------------------------------------m =importBF(platform="cpu")# Import Data & Data Manipulation ------------------------------------------------# Importdata_path = m.load.sim_multivariate_normal(only_path =true)m.data(data_path, sep=',') m.data_on_model = pybuiltins.dict( cafe = jnp.array(m.df.cafe.values, dtype=jnp.int32), wait = jnp.array(m.df.wait.values, dtype=jnp.float32), N_cafes =length(m.df.cafe.unique()), afternoon = jnp.array(m.df.afternoon.values, dtype=jnp.float32))# Define model ------------------------------------------------@BFfunctionmodel(cafe, wait, N_cafes, afternoon) a = m.dist.normal(5, 2, name ="a") b = m.dist.normal(-1, 0.5, name ="b") sigma_cafe = m.dist.exponential(1, shape=(2,), name ="sigma_cafe") sigma = m.dist.exponential( 1, name ="sigma") Rho = m.dist.lkj(2, 2, name ="Rho") cov = jnp.outer(sigma_cafe, sigma_cafe) * Rho a_cafe_b_cafe = m.dist.multivariate_normal(jnp.stack([a, b]), cov, shape = [N_cafes], name ="a_b_cafe") a_cafe, b_cafe = a_cafe_b_cafe[:, 0], a_cafe_b_cafe[:, 1] mu = a_cafe[cafe] + b_cafe[cafe] * afternoon m.dist.normal(mu, sigma, obs=wait)end# Run mcmc ------------------------------------------------m.fit(model) # Optimize model parameters through MCMC sampling# Summary ------------------------------------------------m.summary() # Get posterior distributions
Mathematical Details
Centered parameterization
The Gaussian Mixture Model is a hierarchical model where each data point is generated from one of K distinct multivariate Gaussian distributions.
The varying intercepts (\alpha_k) and slopes (\beta_k) are modeled using a Multivariate Normal distribution:
\left(\begin{array}{cc} \bar{\alpha} \\ \bar{\beta} \end{array}\right) is a vector composed from concatenating a parameter for the global intercept and a parameter vector of the global slopes.
\varsigma is a vector giving the standard deviation of the random effects for the intercept and slopes across groups.
\Omega is the correlation matrix.
Non-centered parameterization
For computational reasons, it is often better to implement a non-centered parameterization 🛈 that is equivalent to the Multivariate Normal distribution approach:
\sigma_\alpha \sim \text{Exponential}(1) is the prior standard deviation among intercepts.
\sigma_\beta \sim \text{Exponential}(1) is the prior standard deviation among slopes.
L \sim \text{LKJcorr}(\eta) is the a cholesky factor of the correlation matrix matrix using the Cholesky Factor 🛈
Multivariate Model with One Random Slope for Each Variable
We can apply a multivariate model similarly to Chapter 2. In this case, we apply the same principle, but with a covariance matrix with a dimension equal to the number of varying slopes we define. For example, if we want to generate random slopes for i observations in a model with two independent variables X_1 and X_2, we can define the formula as follows: